import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.multiprocessing as mp
import torch.distributed as dist
import argparse
import os
import numpy as np
from efficientnet_pytorch import EfficientNet

from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts

from senet.se_resnet import FineTuneSEResnet50
from utils.init import init_weights
from utils.transform import get_transform_for_train, get_transform_for_test
from utils.loss_function import LabelSmoothingCrossEntropy
from utils.utils import adjust_learning_rate, accuracy, cosine_anneal_schedule
from utils.save import save_checkpoint
from utils.cutmix import cutmix_data
# from train import train
# from validate import validate
from graph_rise.graph_regularization import get_images_info, graph_rise_loss

os.environ['CUDA_VISIBLE_DEVICES'] = "2,3"
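# With CUDA_VISIBLE_DEVICES="2,3" the two visible GPUs are remapped to cuda:0
# and cuda:1 inside this process, which is why ngpus_per_node defaults to 2.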

parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--dataroot', default='/data/userdata/set100-80/annotated_images_448/test1', type=str)
parser.add_argument('--logs_dir', default='./weights_dir/efficientnet-b5/test1', type=str)
parser.add_argument('--weights_dir', default='./weights_dir/efficientnet-b5/test1', type=str)
parser.add_argument('--test_weights_path', default="")
parser.add_argument('--init_type',  default='', type=str)
parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float)

parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--start_epochs', default=0, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--test_batch_size', default=32, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--img_size', default=448, type=int)
parser.add_argument('--eval_epoch', default=1, type=int)
parser.add_argument('--nclass', default=113, type=int)
parser.add_argument('--multi_gpus', default=[0, 1, 2, 3], nargs='+', type=int)
parser.add_argument('--gpu_nums', default=1, type=int)
parser.add_argument('--resume', default=r"", type=str)
parser.add_argument('--milestones', default=[120, 220, 270], nargs='+', type=int)

# boolean switches use store_true actions: argparse does not parse strings
# like "False" as booleans, so default=False without a type is a silent trap
parser.add_argument('--graph_reg', action='store_true')
parser.add_argument('--label_smooth', action='store_true')
parser.add_argument('--cutmix', action='store_true')
parser.add_argument('--mixup', action='store_true')
parser.add_argument('--cosine_decay', action='store_true', default=True)  # enabled by default; edit here to disable

parser.add_argument('--rank', default=0, type=int, help='node rank for distributed training')
parser.add_argument('--ngpus_per_node', default=2, type=int)
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
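
# Example launch (hypothetical script name; paths are placeholders):
#   python train_efficientnet_b5.py --dataroot /path/to/dataset --cutmix --label_smooth
# main() below spawns ngpus_per_node worker processes via torch.multiprocessing.spawn.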

best_prec1 = 0
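
def mixup_data(inputs, labels, alpha=1.0):
    # train() references mixup_data but the original file never imports or
    # defines it. This is a minimal sketch of standard mixup (Zhang et al.,
    # 2018) under that assumption; swap in the project's own helper (e.g.
    # alongside utils.cutmix) if one exists.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(inputs.size(0), device=inputs.device)
    mixed_inputs = lam * inputs + (1.0 - lam) * inputs[index]
    return mixed_inputs, labels, labels[index], lam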


def main():
    print('Part1 : prepare for parameters <==> Begin')
    args = parser.parse_args()

    ngpus_per_node = args.ngpus_per_node
    print('ngpus_per_node:', ngpus_per_node)
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def train(args, train_loader, model, criterion, optimizer, epoch, name_list, name_dict, ngpus_per_node):
    # switch to train mode
    model.train()

    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.cuda(args.gpu, non_blocking=True)
        labels = labels.cuda(args.gpu, non_blocking=True)
        # cutmix
        if args.cutmix:
            inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels)
            outputs = model(inputs)
            loss = criterion(outputs, labels_a) * lam + criterion(outputs, labels_b) * (1. - lam)
        # mixup
        elif args.mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)
            outputs = model(inputs)
            loss = criterion(outputs, labels_a) * lam + criterion(outputs, labels_b) * (1. - lam)
        else:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # graph-rise regularization
        if args.graph_reg:
            graph_loss = graph_rise_loss(outputs, labels, name_list, name_dict)
            loss = loss + graph_loss

        # measure accuracy and record loss
        prec1, prec3 = accuracy(outputs, labels, topk=(1, 3))  # this is metric on trainset

        # compute gradient and take an SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0 and args.rank % ngpus_per_node == 0:
            log_line = 'Train Epoch: {0} Step: {1}/{2} Loss {loss:.4f} Top1 {top1:.3f} Top3 {top3:.3f} LR {lr:.7f}'.format(
                epoch, i, len(train_loader), loss=loss.item(), top1=prec1[0], top3=prec3[0], lr=optimizer.param_groups[0]['lr'])
            print(log_line)

            # os.makedirs handles the nested default path; os.mkdir would fail
            # when the parent directories do not exist yet
            os.makedirs(args.logs_dir, exist_ok=True)
            logs_file = os.path.join(args.logs_dir, 'log_train.txt')

            with open(logs_file, 'a') as f:
                f.write(log_line + '\n')


def validate(args, val_loader, model, criterion, epoch, ngpus_per_node):
    prec1_list = []
    prec5_list = []
    model.eval()

    with torch.no_grad():  # evaluation needs no gradients; saves memory and time
        for i, (inputs, labels) in enumerate(val_loader):
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

            # compute output
            outputs = model(inputs)

            # measure accuracy on this process's shard of the test set
            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            prec1_list.append(prec1[0].item())
            prec5_list.append(prec5[0].item())

    top1_avg = np.mean(prec1_list)
    top5_avg = np.mean(prec5_list)
    if args.rank % ngpus_per_node == 0:
        log_line = 'Test Epoch: {} Top1 {:.3f}% Top5 {:.3f}%'.format(epoch, top1_avg, top5_avg)
        print(log_line)

        os.makedirs(args.logs_dir, exist_ok=True)
        logs_file = os.path.join(args.logs_dir, 'log_test.txt')

        with open(logs_file, 'a') as f:
            f.write(log_line + '\n')
    return top1_avg
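
def reduce_mean(value, world_size):
    # Optional helper, not in the original script: validate() above averages
    # accuracy over this process's shard of the test set only. A sketch like
    # this all-reduces the metric so every rank reports the global average,
    # e.g. top1_avg = reduce_mean(top1_avg, ngpus_per_node).
    tensor = torch.tensor(value, dtype=torch.float32, device='cuda')
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return (tensor / world_size).item()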


def main_worker(gpu, ngpus_per_node, args):
    global best_prec1
    args.gpu = gpu

    args.rank = args.rank * ngpus_per_node + gpu
    # pass the global rank, not the local GPU index; world_size and the
    # localhost init_method assume single-node training
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456',
                            world_size=ngpus_per_node, rank=args.rank)
    print('rank', args.rank, ' use multi-gpus...')

    name_list, name_dict = get_images_info()

    if args.rank % ngpus_per_node == 0:
        print('Part1 : prepare for parameters <==> Done')
        print('Part2 : Load Network  <==> Begin')
    model = EfficientNet.from_pretrained('efficientnet-b5', num_classes=args.nclass)
    cudnn.benchmark = True
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    if args.label_smooth:
        criterion = LabelSmoothingCrossEntropy()
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    if args.cosine_decay:
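        # with T_0 equal to the total epoch count this is effectively a single
        # cosine decay over the whole run; a smaller T_0 would produce actual
        # warm restarts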
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=args.epochs)
    else:
        scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # deserialize the checkpoint into a Python dict
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epochs = checkpoint['epoch'] + 1  # resume from the epoch after the saved one
            best_prec1 = checkpoint['prec1']
#            if args.gpu is not None:
#                # best_acc1 may be from a checkpoint from a different GPU
#                best_prec1 = best_prec1.to(args.gpu)
            # restore model and optimizer state and resume training from the checkpoint
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('Resuming from epoch {} with best accuracy {:.3f}'.format(args.start_epochs, best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.rank % ngpus_per_node == 0:
        print('Part2 : Load Network  <==> Done')
        print('Part3 : Load Dataset  <==> Begin')

    dataroot = os.path.abspath(args.dataroot)
    traindir = os.path.join(dataroot, 'train_images')
    testdir = os.path.join(dataroot, 'test_images')

    # ImageFolder
    # mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]
    transform_train = get_transform_for_train(mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866])
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866])

    train_dataset = ImageFolder(traindir, transform=transform_train)
    test_dataset = ImageFolder(testdir, transform=transform_test)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)

    # data loaders: shuffling is delegated to the DistributedSampler, hence shuffle=False
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, pin_memory=True, num_workers=16, sampler=train_sampler)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=16, sampler=test_sampler)

    if args.rank % ngpus_per_node == 0:
        print('Part3 : Load Dataset  <==> Done')
        print('Part4 : Train and Test  <==> Begin')

    for epoch in range(args.start_epochs, args.epochs):
        # reshuffle the DistributedSampler's shard assignment each epoch
        train_sampler.set_epoch(epoch)
        # adjust_learning_rate(args, optimizer, epoch, gamma=0.1)

        # train for one epoch
        train(args, train_loader, model, criterion, optimizer, epoch, name_list, name_dict, ngpus_per_node)

        # evaluate on validation set
        if epoch % args.eval_epoch == 0:
            prec1 = validate(args, test_loader, model, criterion, epoch, ngpus_per_node)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if args.rank % ngpus_per_node == 0:
                if not is_best:
                    print('Top1 accuracy stays at {:.3f}'.format(best_prec1))
                else:   # save the best model
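                    # the state_dict keys carry DDP's 'module.' prefix; resume
                    # loads into the DDP-wrapped model, so the prefixes match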
                    save_checkpoint(args, state_dict={
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'prec1': prec1,
                    })
                    print('Saved new best model with accuracy {:.3f}'.format(best_prec1))
        scheduler.step()
    if args.rank % ngpus_per_node == 0:
        print('Part4 : Train and Test  <==> Done')



if __name__ == '__main__':
    main()